Graph Data Embedding¶
  • NetworkX for creating the graph.
    • Create the sample graph: Nodes are classified as either "Birds" or "Timber", with edges representing the "origin" relationships.
    • create_sample_graph: This function creates a sample graph with two categories: "Birds" (e.g., eagle, parrot, sparrow) and "Timber" (e.g., oak, cedar, maple), with edges representing the "origin" relationship.
  • Transformers from Hugging Face for generating embeddings.
    • Generate embeddings: Use a transformer-based model (like BERT or a domain-specific model) to create embeddings for each node in the graph.
      • get_bert_embeddings: This function generates embeddings using a pre-trained BERT model. The embeddings for each node (bird or timber) are created by tokenizing the node name and using the mean of the last hidden states.
      • create_embeddings_for_graph: This function generates embeddings for all nodes in the graph.
  • Faiss for storing and querying the embeddings.
    • Store embeddings in Faiss: The embeddings are indexed using Faiss for efficient similarity search.
    • User Query: The system will retrieve the top 3 most similar graph nodes based on the user's query.
    • store_embeddings_in_faiss: This function stores the embeddings in a FAISS index to enable efficient similarity search.
    • retrieve_top_k_similar: This function retrieves the top-k most similar nodes to a given query using FAISS.
  • Matplotlib and NetworkX for displaying graph structures.
    • Display Results: The original graph and the predicted graph data are displayed in tables, and the network is visualized.
In [ ]:
%pip install -q networkx faiss-cpu transformers pandas matplotlib torch
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
petastorm 0.12.1 requires pyspark>=2.1.0, which is not installed.
databricks-feature-store 0.14.3 requires pyspark<4,>=3.1.2, which is not installed.
ydata-profiling 4.2.0 requires numpy<1.24,>=1.16.0, but you have numpy 2.1.3 which is incompatible.
scipy 1.9.1 requires numpy<1.25.0,>=1.18.5, but you have numpy 2.1.3 which is incompatible.
numba 0.55.1 requires numpy<1.22,>=1.18, but you have numpy 2.1.3 which is incompatible.
mleap 0.20.0 requires scikit-learn<0.23.0,>=0.22.0, but you have scikit-learn 1.1.1 which is incompatible.
langchain 0.0.217 requires numpy<2,>=1, but you have numpy 2.1.3 which is incompatible.
databricks-feature-store 0.14.3 requires numpy<2,>=1.19.2, but you have numpy 2.1.3 which is incompatible.
Note: you may need to restart the kernel using dbutils.library.restartPython() to use updated packages.
In [ ]:
import networkx as nx
import numpy as np
import pandas as pd
import faiss
from transformers import BertTokenizer, BertModel
import torch
import matplotlib.pyplot as plt
import textwrap

# 1. Create the sample graph data with two categories (Birds and Timber)
def create_sample_graph():
    G = nx.Graph()

    # Add bird nodes and their origin edges
    birds = ["eagle", "parrot", "sparrow", "emu"]
    bird_origins = {"eagle": "USA", "parrot": "Australia", "sparrow": "Europe", "emu": "Australia"}
    for bird in birds:
        G.add_node(bird, category="bird", origin=bird_origins[bird])
    
    # Add timber nodes and their origin edges
    timbers = ["oak", "cedar", "maple", "tasmanian-oak"]
    timber_origins = {"oak": "USA", "cedar": "Canada", "maple": "Canada", "tasmanian-oak" : "Australia"}
    for timber in timbers:
        G.add_node(timber, category="timber", origin=timber_origins[timber])

    # Add edges representing origin relationships
    G.add_edge("eagle", "parrot", relation="origin")
    G.add_edge("parrot", "sparrow", relation="origin")
    G.add_edge("emu", "parrot", relation="origin")
    G.add_edge("oak", "cedar", relation="origin")
    G.add_edge("cedar", "maple", relation="origin")
    G.add_edge("tasmanian-oak", "oak", relation="origin")

    return G, bird_origins, timber_origins

# 2. Generate BERT-based embeddings for the graph nodes
def get_bert_embeddings(texts):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = BertModel.from_pretrained("bert-base-uncased")

    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Use the mean of the last layer hidden states as embeddings
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.numpy()

# 3. Create embeddings for graph nodes
def create_embeddings_for_graph(G):
    node_names = list(G.nodes)
    node_texts = [f"{node} ({G.nodes[node]['category']} from {G.nodes[node]['origin']})" for node in node_names]

    # Generate embeddings for each node using BERT
    embeddings = get_bert_embeddings(node_texts)
    
    # Store embeddings in a DataFrame for easy inspection
    embedding_df = pd.DataFrame(embeddings, index=node_names, columns=[f"dim_{i}" for i in range(embeddings.shape[1])])
    return embedding_df, embeddings

# 4. Store embeddings in FAISS index for similarity search
def store_embeddings_in_faiss(embeddings):
    dim = embeddings.shape[1]  # Dimensionality of embeddings
    index = faiss.IndexFlatL2(dim)  # Use L2 distance metric
    index.add(embeddings.astype(np.float32))  # Add embeddings to FAISS index
    return index

# 5. Retrieve the top k most similar nodes based on a query
def retrieve_top_k_similar(index, query_embedding, k=2):
    D, I = index.search(query_embedding.astype(np.float32), k)
    return I, D

# Function to wrap text inside nodes for better visualization
def wrap_text(label, width=10):
    return "\n".join(textwrap.wrap(label, width))

# 6. Visualize the graph using NetworkX (with edges showing the origin information)
def visualize_graph_side_by_side(G, birds, timbers, title="Graph"):
    fig, axes = plt.subplots(1, 2, figsize=(14, 7))  # 1 row, 2 columns for side-by-side plots

    # Generate subgraphs for birds and timbers
    G_birds = G.subgraph(birds)
    G_timbers = G.subgraph(timbers)

    # First subplot for Birds
    ax1 = axes[0]
    pos_birds = nx.spring_layout(G_birds, seed=42)  # Generate positions for the subgraph

    # Wrap node labels for better readability
    labels_birds = {node: wrap_text(node) for node in G_birds.nodes()}

    # Draw the graph with wrapped labels
    nx.draw(G_birds, pos_birds, with_labels=True, labels=labels_birds, node_color="lightblue", font_weight="bold", node_size=3000, font_size=12, ax=ax1)
    edge_labels_birds = { (u, v): f"{G.nodes[u]['origin']} -> {G.nodes[v]['origin']}" for u, v in G_birds.edges() }
    nx.draw_networkx_edge_labels(G_birds, pos_birds, edge_labels=edge_labels_birds, font_size=12, font_color='red', ax=ax1)
    ax1.set_title("Birds")

    # Second subplot for Timbers
    ax2 = axes[1]
    pos_timbers = nx.spring_layout(G_timbers, seed=42)  # Generate positions for the subgraph

    # Wrap node labels for better readability
    labels_timbers = {node: wrap_text(node) for node in G_timbers.nodes()}

    # Draw the graph with wrapped labels
    nx.draw(G_timbers, pos_timbers, with_labels=True, labels=labels_timbers, node_color="lightgreen", font_weight="bold", node_size=3000, font_size=12, ax=ax2)
    edge_labels_timbers = { (u, v): f"{G.nodes[u]['origin']} -> {G.nodes[v]['origin']}" for u, v in G_timbers.edges() }
    nx.draw_networkx_edge_labels(G_timbers, pos_timbers, edge_labels=edge_labels_timbers, font_size=12, font_color='red', ax=ax2)
    ax2.set_title("Timbers")

    plt.tight_layout()
    plt.show()

# Main function to simulate the full process
def main():
    # 1. Create the sample graph
    G, bird_origins, timber_origins = create_sample_graph()
    
    # 2. Generate embeddings for the graph
    embedding_df, embeddings = create_embeddings_for_graph(G)
    
    # Display original graph data with first 2 dimensions of embeddings
    print("Graph Data (first 2 embeddings of nodes):")
    
    # Create a DataFrame to display the graph edges and their relations
    original_df = pd.DataFrame(list(G.edges(data=True)), columns=["Node1", "Node2", "Relation"])
    
    # Map origin information to the Relation column based on nodes
    def get_relation_value(node1, node2):
        # Check if both nodes are from the bird category or timber category
        if node1 in bird_origins and node2 in bird_origins:
            return bird_origins.get(node1, "Unknown") + " -> " + bird_origins.get(node2, "Unknown")
        elif node1 in timber_origins and node2 in timber_origins:
            return timber_origins.get(node1, "Unknown") + " -> " + timber_origins.get(node2, "Unknown")
        else:
            return "Unknown"
    
    # Apply the get_relation_value function to each edge
    original_df["Relation"] = original_df.apply(lambda row: get_relation_value(row["Node1"], row["Node2"]), axis=1)

    # Add only the first 2 dimensions of embeddings to original_df
    node_embeddings = {node: embeddings[i][:2] for i, node in enumerate(G.nodes)}

    # Then, for each row in original_df, map the nodes to their embeddings
    original_df['Node1 Embedding'] = original_df['Node1'].map(node_embeddings)
    original_df['Node2 Embedding'] = original_df['Node2'].map(node_embeddings)
    
    # Add the index as the first column in original_df
    original_df.reset_index(inplace=True)
    original_df.rename(columns={'index': 'Index'}, inplace=True)
    
    print(original_df.to_markdown(index=False))

    # 6. Visualize the graph (Birds and Timbers side by side)
    print("\nVisualising of Nodes and Edges")  # Birds and Timbers Graph...
    birds = [n for n, d in G.nodes(data=True) if d['category'] == 'bird']
    timbers = [n for n, d in G.nodes(data=True) if d['category'] == 'timber']
    visualize_graph_side_by_side(G, birds, timbers, title="Birds and Timbers")

    # 3. Store embeddings in FAISS
    index = store_embeddings_in_faiss(embeddings)
    
    # 4. Simulate a user query (e.g., querying for "parrot")
    user_query = "emu"
    query_embedding = get_bert_embeddings([user_query])
    
    # 5. Retrieve the top 3 similar nodes based on the user query
    indices, distances = retrieve_top_k_similar(index, query_embedding, k=3)
    
    # Retrieve the predicted nodes based on indices
    predicted_nodes = [list(G.nodes)[i] for i in indices[0]]
    
    # Display user query and predicted graph data (top 3 similar nodes)
    print(f"\nUser Query: {user_query} ")
    predicted_df = pd.DataFrame({
        "Index": indices[0],  # Include index in predicted_df
        "Predicted Node": predicted_nodes,
        "Embedding Dimension (first 2 embeddings)": [embeddings[i][:2] for i in indices[0]]  # Slice the first 5 dimensions
    })
    
    # 7. Add the Relation column to predicted_df based on origin information
    def get_predicted_relation(node):
        if node in bird_origins:
            return bird_origins[node]
        elif node in timber_origins:
            return timber_origins[node]
        else:
            return "Unknown"
    
    # Populate the 'Relation' column
    predicted_df['Relation'] = predicted_df['Predicted Node'].map(get_predicted_relation)
    
    # Sort predicted_df by 'Index' column in ascending order
    predicted_df = predicted_df.sort_values(by="Index", ascending=True)
    
    print(predicted_df.to_markdown(index=False))
    
    # 6. Visualize the predicted graph (subgraph of predicted nodes)
    print("\nVisualizing Predicted Node Graph")
    predicted_G = G.subgraph(predicted_nodes)
    visualize_graph_side_by_side(predicted_G, birds, timbers, title="Predicted Graph")

if __name__ == "__main__":
    main()
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Graph Data (first 2 embeddings of nodes):
|   Index | Node1   | Node2         | Relation               | Node1 Embedding           | Node2 Embedding           |
|--------:|:--------|:--------------|:-----------------------|:--------------------------|:--------------------------|
|       0 | eagle   | parrot        | USA -> Australia       | [-0.08763738 -0.18008912] | [-0.21244605 -0.08693982] |
|       1 | parrot  | sparrow       | Australia -> Europe    | [-0.21244605 -0.08693982] | [-0.30793023 -0.19978818] |
|       2 | parrot  | emu           | Australia -> Australia | [-0.21244605 -0.08693982] | [-0.28172386 -0.08433155] |
|       3 | oak     | cedar         | USA -> Canada          | [-0.19053563 -0.02666191] | [-0.15451434  0.1392495 ] |
|       4 | oak     | tasmanian-oak | USA -> Australia       | [-0.19053563 -0.02666191] | [-0.11889555  0.14518885] |
|       5 | cedar   | maple         | Canada -> Canada       | [-0.15451434  0.1392495 ] | [-0.17470792  0.18844433] |

Visualising of Nodes and Edges
No description has been provided for this image
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
User Query: emu 
|   Index | Predicted Node   | Embedding Dimension (first 2 embeddings)   | Relation   |
|--------:|:-----------------|:-------------------------------------------|:-----------|
|       1 | parrot           | [-0.21244605 -0.08693982]                  | Australia  |
|       3 | emu              | [-0.28172386 -0.08433155]                  | Australia  |
|       4 | oak              | [-0.19053563 -0.02666191]                  | USA        |

Visualizing Predicted Node Graph
No description has been provided for this image